{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d8513771-36ac-4d03-b890-35108bce2211",
   "metadata": {},
   "source": [
    "### Loading and processing IMDB movie review dataset\n",
    "In this example, we will load the IMDB dataset from Hugging Face,\n",
    "then use `torchdata.nodes` to process it and generate training batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "eb3b507c-2ad1-410d-a834-6847182de684",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import BertTokenizer, BertForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "089f1126-7125-4274-9d71-5c949ccc7bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import default_collate, RandomSampler, SequentialSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2afac7d9-3d66-4195-8647-dc7034d306f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load IMDB dataset from huggingface datasets and select the \"train\" split\n",
    "dataset = load_dataset(\"imdb\", streaming=False)\n",
    "dataset = dataset[\"train\"]\n",
    "# Since dataset is a Map-style dataset, we can set up a sampler to shuffle the data\n",
    "# Please refer to the migration guide here https://pytorch.org/data/main/migrate_to_nodes_from_utils.html\n",
    "# to migrate from torch.utils.data to torchdata.nodes\n",
    "\n",
    "sampler = RandomSampler(dataset)\n",
    "# Use a standard bert tokenizer\n",
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "# Now we can set up some torchdata.nodes to create our pre-proc pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09e08a47-573c-4d32-9a02-36cd8150db60",
   "metadata": {},
   "source": [
    "All torchdata.nodes.BaseNode implementations are Iterators.\n",
    "`MapStyleWrapper` creates an Iterator that draws indices from the sampler and yields the matching items from the dataset.\n",
    "Under the hood, MapStyleWrapper just does:\n",
    "```python\n",
    "node = IterableWrapper(sampler)\n",
    "node = Mapper(node, map_fn=dataset.__getitem__)  # You can parallelize this with ParallelMapper\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "02af5479-ee69-41d8-ab2d-bf154b84bc15",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n",
    "node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n",
    "\n",
    "# Now we want to transform the raw inputs. We can just use another Mapper with\n",
    "# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n",
    "# threads (or processes) to parallelize this work and have it run in the background\n",
    "max_len = 512\n",
    "batch_size = 2\n",
    "def bert_transform(item):\n",
    "    \"\"\"Tokenize one raw sample into fixed-length BERT inputs.\n",
    "\n",
    "    Args:\n",
    "        item: Mapping with a \"text\" string and a \"label\" value, as read\n",
    "            from the dataset.\n",
    "\n",
    "    Returns:\n",
    "        Dict with 1-D \"input_ids\" and \"attention_mask\" tensors (padded or\n",
    "        truncated to `max_len` tokens) and a scalar int64 \"labels\" tensor.\n",
    "    \"\"\"\n",
    "    # Calling the tokenizer directly replaces the deprecated `encode_plus`;\n",
    "    # the arguments and the returned encoding are the same for a single text.\n",
    "    encoding = tokenizer(\n",
    "        item[\"text\"],\n",
    "        add_special_tokens=True,\n",
    "        max_length=max_len,\n",
    "        padding=\"max_length\",\n",
    "        truncation=True,\n",
    "        return_attention_mask=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "    # return_tensors=\"pt\" adds a leading batch dimension of size 1; flatten it\n",
    "    # away so that default_collate can stack samples into (batch, max_len).\n",
    "    return {\n",
    "        \"input_ids\": encoding[\"input_ids\"].flatten(),\n",
    "        \"attention_mask\": encoding[\"attention_mask\"].flatten(),\n",
    "        \"labels\": torch.tensor(item[\"label\"], dtype=torch.long),\n",
    "    }\n",
    "node = ParallelMapper(node, map_fn=bert_transform, num_workers=2) # output items are Dict[str, tensor]\n",
    "\n",
    "# Next we batch the inputs, and then apply a collate_fn with another Mapper\n",
    "# to stack the tensors of each batch together. We use torch.utils.data.default_collate for this\n",
    "node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n",
    "node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n",
    "\n",
    "# we can optionally apply pin_memory to the batches\n",
    "if torch.cuda.is_available():\n",
    "    node = PinMemory(node)\n",
    "\n",
    "# Since nodes are iterators, they need to be manually .reset() between epochs.\n",
    "# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n",
    "loader = Loader(node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "60fd54f3-62ef-47aa-a790-853cb4899f13",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': tensor([[ 101, 1045, 2572,  ..., 2143, 2000,  102],\n",
      "        [ 101, 2004, 1037,  ...,    0,    0,    0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],\n",
      "        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([0, 1])}\n"
     ]
    }
   ],
   "source": [
    "# Inspect a batch\n",
    "batch = next(iter(loader))\n",
    "print(batch)\n",
    "# In a batch we get three keys, as defined in the function `bert_transform`.\n",
    "# Since the batch size is 2, two samples are stacked together for each key."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
